from langchain.document_loaders import UnstructuredURLLoader
from langchain.chains.summarize import load_summarize_chain
from langchain.callbacks import get_openai_callback
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.llms import ChatGLM
from langchain.chains import LLMChain

# # Create an LLMChain object with the prompt and ChatGLM model
# llm_chain = LLMChain(prompt=prompt, llm=llm)

# # Ask a question and get the response from ChatGLM
# question = "北京和上海两座城市有什么不同？"
# response = llm_chain.run(question)

def summarize_docs(docs, doc_url):
    """Summarize loaded documents with a map-reduce summarization chain.

    Args:
        docs: Documents as returned by a LangChain loader; each item is
            expected to expose ``page_content``.
        doc_url: Source URL of the documents, used only in log output.

    Returns:
        The summary text produced by the chain, or an empty string when
        *docs* is empty.
    """
    # Guard: an empty loader result would otherwise crash on docs[0] below.
    if not docs:
        print(f'You have 0 document(s) in your {doc_url} data')
        return ""

    print(f'You have {len(docs)} document(s) in your {doc_url} data')
    print(docs)
    print(f'There are {len(docs[0].page_content)} characters in your document')

    # Split into ~1000-character chunks (no overlap) so each piece fits
    # comfortably in the model's context window for the map step.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    split_docs = text_splitter.split_documents(docs)

    print(f'You have {len(split_docs)} split document(s)')

    # Talk to a locally served ChatGLM model through its OpenAI-compatible
    # API. The local server ignores the API key, but the client requires a
    # non-empty value, hence the "EMPTY" placeholder.
    endpoint_url = "http://127.0.0.1:8000/v1"
    from langchain.chat_models import ChatOpenAI
    llm = ChatOpenAI(
            model_name="chatglm",
            openai_api_base=endpoint_url,
            openai_api_key="EMPTY",
            streaming=False,
        )
    chain = load_summarize_chain(llm, chain_type="map_reduce", verbose=False)

    # Token accounting via the OpenAI callback. NOTE(review): the cost
    # figure is only meaningful for real OpenAI models; against a local
    # ChatGLM endpoint it will report $0 or nonsense — confirm if needed.
    with get_openai_callback() as cb:
        response = chain.run(input_documents=split_docs)
        print(f"Total Tokens: {cb.total_tokens}")
        print(f"Prompt Tokens: {cb.prompt_tokens}")
        print(f"Completion Tokens: {cb.completion_tokens}")
        print(f"Successful Requests: {cb.successful_requests}")
        print(f"Total Cost (USD): ${cb.total_cost}")

    return response

if __name__ == "__main__":
    # NOTE(review): the original had a first `url` assignment
    # (a CSDN blog post) that was immediately overwritten and therefore
    # dead; only the URL actually used is kept.
    url = "https://edition.cnn.com/2023/04/13/business/delta-earnings/index.html"
    # Guarding with __main__ keeps `import` of this module side-effect
    # free: the network fetch and LLM calls only run when executed as a
    # script.
    res = summarize_docs(UnstructuredURLLoader(urls = [url]).load(), url)
    print(res)

# import os
# from typing import List
# from ChatGLM3 import ChatGLM3

# from langchain.agents import load_tools
# from Tool.Weather import Weather
# from Tool.Calculator import Calculator
# from langchain.agents import initialize_agent
# from langchain.agents import AgentType

# MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')

# def run_tool(tools, llm, prompt_chain: List[str]):
#     loaded_tolls = []
#     for tool in tools:
#         if isinstance(tool, str):
#             loaded_tolls.append(load_tools([tool], llm=llm)[0])
#         else:
#             loaded_tolls.append(tool)
#     agent = initialize_agent(
#         loaded_tolls, llm,
#         agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
#         verbose=True,
#         handle_parsing_errors=True
#     )
#     for prompt in prompt_chain:
#         agent.run(prompt)


# if __name__ == "__main__":
#     llm = ChatGLM3()
#     llm.load_model(model_name_or_path=MODEL_PATH)

#     # arxiv: 单个工具调用示例 1
#     run_tool(["arxiv"], llm, [
#         "帮我查询GLM-130B相关工作"
#     ])

#     # weather: 单个工具调用示例 2
#     run_tool([Weather()], llm, [
#         "今天北京天气怎么样？",
#         "What's the weather like in Shanghai today",
#     ])

#     # calculator: 单个工具调用示例 3
#     run_tool([Calculator()], llm, [
#         "12345679乘以54等于多少？",
#         "3.14的3.14次方等于多少？",
#         "根号2加上根号三等于多少？",
#     ]),

#     # arxiv + weather + calculator: 多个工具结合调用
#     # run_tool([Calculator(), "arxiv", Weather()], llm, [
#     #     "帮我检索GLM-130B相关论文",
#     #     "今天北京天气怎么样？",
#     #     "根号3减去根号二再加上4等于多少？",
#     # ])
